{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notebook written by [Zhedong Zheng](https://github.com/zhedongzheng)\n",
    "\n",
    "<img src=\"img/cnn_seq.png\" width=\"300\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pos\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from sklearn.metrics import classification_report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = {\n",
    "    'seq_len': 20,\n",
    "    'batch_size': 128,\n",
    "    'hidden_dim': 128,\n",
    "    'kernel_sizes': [3, 5],\n",
    "    'text_iter_step': 1,\n",
    "    'lr': {'start': 5e-3, 'end': 5e-4},\n",
    "    'n_epoch': 1,\n",
    "    'display_step': 50,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_test_seq(*args):\n",
    "    return [np.reshape(x[:(len(x)-len(x)%params['seq_len'])],\n",
    "        [-1,params['seq_len']]) for x in args]\n",
    "\n",
    "def iter_seq(x):\n",
    "    return np.array([x[i: i+params['seq_len']] for i in range(\n",
    "        0, len(x)-params['seq_len'], params['text_iter_step'])])\n",
    "\n",
    "def to_train_seq(*args):\n",
    "    return [iter_seq(x) for x in args]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Vocab Size: 19124 | x_train: 211727 | x_test: 47377\n"
     ]
    }
   ],
   "source": [
    "def pipeline_train(X, y, sess):\n",
    "    dataset = tf.data.Dataset.from_tensor_slices((X, y))\n",
    "    dataset = dataset.shuffle(len(X)).batch(params['batch_size'])\n",
    "    iterator = dataset.make_initializable_iterator()\n",
    "    X_ph = tf.placeholder(tf.int32, [None, params['seq_len']])\n",
    "    y_ph = tf.placeholder(tf.int32, [None, params['seq_len']])\n",
    "    init_dict = {X_ph: X, y_ph: y}\n",
    "    sess.run(iterator.initializer, init_dict)\n",
    "    return iterator, init_dict\n",
    "\n",
    "def pipeline_test(X, sess):\n",
    "    dataset = tf.data.Dataset.from_tensor_slices(X)\n",
    "    dataset = dataset.batch(params['batch_size'])\n",
    "    iterator = dataset.make_initializable_iterator()\n",
    "    X_ph = tf.placeholder(tf.int32, [None, params['seq_len']])\n",
    "    init_dict = {X_ph: X}\n",
    "    sess.run(iterator.initializer, init_dict)\n",
    "    return iterator, init_dict\n",
    "\n",
    "\n",
    "x_train, y_train, x_test, y_test, params['vocab_size'], params['n_class'], word2idx, tag2idx = pos.load_data()\n",
    "X_train, Y_train = to_train_seq(x_train, y_train)\n",
    "X_test, Y_test = to_test_seq(x_test, y_test)\n",
    "\n",
    "sess = tf.Session()\n",
    "params['lr']['steps'] = len(X_train) // params['batch_size']\n",
    "\n",
    "iter_train, init_dict_train = pipeline_train(X_train, Y_train, sess)\n",
    "iter_test, init_dict_test = pipeline_test(X_test, sess)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward(x, reuse, is_training):\n",
    "    with tf.variable_scope('model', reuse=reuse):\n",
    "        x = tf.contrib.layers.embed_sequence(x, params['vocab_size'], params['hidden_dim'])\n",
    "        x = tf.layers.dropout(x, 0.1, training=is_training)\n",
    "        \n",
    "        pad = tf.zeros([tf.shape(x)[0], 1, params['hidden_dim']])\n",
    "        for k_sz in params['kernel_sizes']:\n",
    "            n = (k_sz - 1) // 2\n",
    "            _x = tf.concat([pad]*n + [x] + [pad]*n, 1)\n",
    "            x += tf.layers.conv1d(_x, params['hidden_dim'], k_sz, activation=tf.nn.relu)\n",
    "\n",
    "        logits = tf.layers.dense(x, params['n_class'])\n",
    "    return logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "`NHWC` for data_format is deprecated, use `NWC` instead\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py:98: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.\n",
      "  \"Converting sparse IndexedSlices to a dense Tensor of unknown shape. \"\n"
     ]
    }
   ],
   "source": [
    "ops = {}\n",
    "\n",
    "X_train_batch, y_train_batch = iter_train.get_next()\n",
    "X_test_batch = iter_test.get_next()\n",
    "\n",
    "logits_tr = forward(X_train_batch, reuse=False, is_training=True)\n",
    "logits_te = forward(X_test_batch, reuse=True, is_training=False)\n",
    "\n",
    "log_likelihood, trans_params = tf.contrib.crf.crf_log_likelihood(\n",
    "    logits_tr, y_train_batch, tf.count_nonzero(X_train_batch, 1))\n",
    "\n",
    "ops['loss'] = tf.reduce_mean(-log_likelihood)\n",
    "\n",
    "ops['global_step'] = tf.Variable(0, trainable=False)\n",
    "\n",
    "ops['lr'] = tf.train.exponential_decay(\n",
    "    params['lr']['start'], ops['global_step'], params['lr']['steps'],\n",
    "    params['lr']['end']/params['lr']['start'])\n",
    "\n",
    "ops['train'] = tf.train.AdamOptimizer(ops['lr']).minimize(\n",
    "    ops['loss'], global_step=ops['global_step'])\n",
    "\n",
    "ops['crf_decode'] = tf.contrib.crf.crf_decode(\n",
    "    logits_te, trans_params, tf.count_nonzero(X_test_batch, 1))[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 | Step 1 | Loss 74.740 | LR: 0.0050\n",
      "Epoch 1 | Step 50 | Loss 3.259 | LR: 0.0047\n",
      "Epoch 1 | Step 100 | Loss 1.595 | LR: 0.0043\n",
      "Epoch 1 | Step 150 | Loss 0.938 | LR: 0.0041\n",
      "Epoch 1 | Step 200 | Loss 0.494 | LR: 0.0038\n",
      "Epoch 1 | Step 250 | Loss 0.300 | LR: 0.0035\n",
      "Epoch 1 | Step 300 | Loss 0.412 | LR: 0.0033\n",
      "Epoch 1 | Step 350 | Loss 0.269 | LR: 0.0031\n",
      "Epoch 1 | Step 400 | Loss 0.187 | LR: 0.0029\n",
      "Epoch 1 | Step 450 | Loss 0.208 | LR: 0.0027\n",
      "Epoch 1 | Step 500 | Loss 0.174 | LR: 0.0025\n",
      "Epoch 1 | Step 550 | Loss 0.138 | LR: 0.0023\n",
      "Epoch 1 | Step 600 | Loss 0.157 | LR: 0.0022\n",
      "Epoch 1 | Step 650 | Loss 0.076 | LR: 0.0020\n",
      "Epoch 1 | Step 700 | Loss 0.242 | LR: 0.0019\n",
      "Epoch 1 | Step 750 | Loss 0.157 | LR: 0.0018\n",
      "Epoch 1 | Step 800 | Loss 0.146 | LR: 0.0016\n",
      "Epoch 1 | Step 850 | Loss 0.131 | LR: 0.0015\n",
      "Epoch 1 | Step 900 | Loss 0.099 | LR: 0.0014\n",
      "Epoch 1 | Step 950 | Loss 0.131 | LR: 0.0013\n",
      "Epoch 1 | Step 1000 | Loss 0.077 | LR: 0.0012\n",
      "Epoch 1 | Step 1050 | Loss 0.125 | LR: 0.0012\n",
      "Epoch 1 | Step 1100 | Loss 0.092 | LR: 0.0011\n",
      "Epoch 1 | Step 1150 | Loss 0.110 | LR: 0.0010\n",
      "Epoch 1 | Step 1200 | Loss 0.060 | LR: 0.0009\n",
      "Epoch 1 | Step 1250 | Loss 0.087 | LR: 0.0009\n",
      "Epoch 1 | Step 1300 | Loss 0.091 | LR: 0.0008\n",
      "Epoch 1 | Step 1350 | Loss 0.127 | LR: 0.0008\n",
      "Epoch 1 | Step 1400 | Loss 0.088 | LR: 0.0007\n",
      "Epoch 1 | Step 1450 | Loss 0.059 | LR: 0.0007\n",
      "Epoch 1 | Step 1500 | Loss 0.261 | LR: 0.0006\n",
      "Epoch 1 | Step 1550 | Loss 0.119 | LR: 0.0006\n",
      "Epoch 1 | Step 1600 | Loss 0.131 | LR: 0.0005\n"
     ]
    }
   ],
   "source": [
    "sess.run(tf.global_variables_initializer())\n",
    "for epoch in range(1, params['n_epoch']+1):\n",
    "    while True:\n",
    "        try:\n",
    "            _, step, loss, lr = sess.run([ops['train'],\n",
    "                                          ops['global_step'],\n",
    "                                          ops['loss'],\n",
    "                                          ops['lr']])\n",
    "        except tf.errors.OutOfRangeError:\n",
    "            break\n",
    "        else:\n",
    "            if step % params['display_step'] == 0 or step == 1:\n",
    "                print(\"Epoch %d | Step %d | Loss %.3f | LR: %.4f\" % (epoch, step, loss, lr))\n",
    "    \n",
    "    Y_pred = []\n",
    "    while True:\n",
    "        try:\n",
    "            Y_pred.append(sess.run(ops['crf_decode']))\n",
    "        except tf.errors.OutOfRangeError:\n",
    "            break\n",
    "    Y_pred = np.concatenate(Y_pred)\n",
    "    \n",
    "    if epoch != params['n_epoch']:\n",
    "        sess.run(iter_train.initializer, init_dict_train)\n",
    "        sess.run(iter_test.initializer, init_dict_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/site-packages/sklearn/metrics/classification.py:1428: UserWarning: labels size, 43, does not match size of target_names, 45\n",
      "  .format(len(labels), len(target_names))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "             precision    recall  f1-score   support\n",
      "\n",
      "      <pad>       0.89      0.95      0.92      6639\n",
      "         NN       0.99      1.00      0.99      5070\n",
      "         IN       1.00      1.00      1.00      4020\n",
      "         DT       0.96      0.93      0.95       912\n",
      "        VBZ       0.97      0.94      0.96      1354\n",
      "         RB       0.91      0.90      0.90      1103\n",
      "        VBN       1.00      1.00      1.00      1177\n",
      "         TO       0.92      0.94      0.93      1269\n",
      "         VB       0.83      0.94      0.88      2962\n",
      "         JJ       0.97      0.89      0.93      3034\n",
      "        NNS       0.97      0.89      0.93      4803\n",
      "        NNP       1.00      1.00      1.00      2389\n",
      "          ,       1.00      1.00      1.00      1214\n",
      "         CC       1.00      1.00      1.00       433\n",
      "        POS       1.00      1.00      1.00      1974\n",
      "          .       0.90      0.92      0.91       539\n",
      "        VBP       0.92      0.87      0.89       727\n",
      "        VBG       1.00      1.00      1.00       421\n",
      "       PRP$       0.99      0.96      0.98      1918\n",
      "         CD       1.00      1.00      1.00       323\n",
      "         ``       1.00      1.00      1.00       316\n",
      "         ''       0.97      0.94      0.95      1679\n",
      "        VBD       0.98      1.00      0.99        48\n",
      "         EX       1.00      0.99      1.00       470\n",
      "         MD       1.00      1.00      1.00        11\n",
      "          #       1.00      1.00      1.00        77\n",
      "          (       1.00      1.00      1.00       384\n",
      "          $       1.00      1.00      1.00        77\n",
      "          )       0.90      0.63      0.74       130\n",
      "       NNPS       1.00      1.00      1.00       814\n",
      "        PRP       1.00      0.94      0.97        77\n",
      "        JJS       1.00      1.00      1.00       110\n",
      "         WP       0.92      0.86      0.89        70\n",
      "        RBR       0.96      0.95      0.95       202\n",
      "        JJR       0.99      0.93      0.96       202\n",
      "        WDT       1.00      0.99      0.99        93\n",
      "        WRB       1.00      0.98      0.99        49\n",
      "        RBS       1.00      0.80      0.89        10\n",
      "        PDT       0.80      0.67      0.73        12\n",
      "         RP       1.00      1.00      1.00       238\n",
      "          :       1.00      0.75      0.86         4\n",
      "         FW       1.00      1.00      1.00         4\n",
      "        WP$       1.00      0.50      0.67         2\n",
      "\n",
      "avg / total       0.96      0.96      0.96     47360\n",
      "\n",
      "I love you\n",
      "PRP VBP PRP\n"
     ]
    }
   ],
   "source": [
    "print(classification_report(Y_test.ravel(), Y_pred.ravel(), target_names=tag2idx.keys()))\n",
    "\n",
    "sample = ['I', 'love', 'you']\n",
    "x = np.atleast_2d([word2idx[w] for w in sample] + [0]*(params['seq_len']-len(sample)))\n",
    "\n",
    "ph = tf.placeholder(tf.int32, [None, params['seq_len']])\n",
    "logits = forward(ph, reuse=True, is_training=False)\n",
    "infer_op = tf.contrib.crf.crf_decode(logits, trans_params, tf.count_nonzero(ph, 1))[0]\n",
    "idx2tag = {idx : tag for tag, idx in tag2idx.items()}\n",
    "\n",
    "x = sess.run(infer_op, {ph: x})[0][:len(sample)]\n",
    "print(' '.join(sample))\n",
    "print(' '.join([idx2tag[idx] for idx in x if idx != 0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
